[ExecuTorch][WebGPU] Add 4-bit weight-only quantized linear (et_vk.linear_q4gsw)#20226
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20226
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There is 1 currently active SEV. If your PR is affected, please view it below:
❌ 1 New Failure, 2 Pending, 2 Unrelated Failures
As of commit 8e81436 with merge base 5526971:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review
Claude finished @JulianCloudNTH's task in 4m 47s
Claude Code Review
Overall this is a clean, well-guarded first cut. Validation-before-allocation, the fp32-only byte-size checks, the
Potential bugs / correctness
1. Weight bound as The weight is bound as Consider validating
2. The shader indexes
3. A non-multiple input numel would silently drop the remainder into a smaller
Performance (non-blocking)
4. Scale reload + integer division every K iteration
Minor / style
None of the above are blockers except items 1–2, which I'd recommend guarding (or covering with an odd-shape test) before relying on this in production, since both fail silently via robustness clamping rather than throwing.
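The failure mode behind item 3 can be illustrated in isolation. This is a minimal Python sketch, not the PR's host code: it assumes the row count is derived by integer division of the input element count by K, and the helper name `rows_for_input` is hypothetical.

```python
def rows_for_input(numel: int, k: int) -> int:
    """Return the number of K-wide rows in a flat input, or raise.

    Hypothetical guard: plain integer division (numel // k) would
    silently discard a trailing partial row, so reject that case.
    """
    if k <= 0:
        raise ValueError("K must be positive")
    if numel % k != 0:
        raise ValueError(
            f"input numel {numel} is not a multiple of K={k}; "
            f"{numel % k} trailing element(s) would be dropped"
        )
    return numel // k


# Unchecked division hides the problem: 14 elements with K=4 looks like 3 rows.
silently_truncated = 14 // 4
```

Raising up front turns a silently wrong output shape into an explicit error, in line with the review's preference for throwing over clamping.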
SS-JIA left a comment
Review automatically exported from Phabricator review in Meta.
ec31143 into gh/JulianCloudNTH/23/base
…near_q4gsw)
Pull Request resolved: #20226
Adds the `et_vk.linear_q4gsw` operator (4-bit groupwise-symmetric weight-only linear) to the WebGPU backend: dequantize the packed int4 weight in WGSL (`(q-8)*scale`) and accumulate an fp32 matmul, consuming the serialized `[N, K/2]` uint8 weight directly (no prepack), one workgroup per output row. Mirrors the Vulkan reference (`backends/vulkan/.../impl/QuantizedLinear.cpp`). The dispatch carries a `linear_q4gsw` label for GPU-timestamp-query profiling (mirroring the SDPA kernels). The numerical test suite is in the stacked test diff.
ghstack-source-id: 392908894
@exported-using-ghexport
Differential Revision: [D108312283](https://our.internmc.facebook.com/intern/diff/D108312283/)
…near_q4gsw) (#20262)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20226 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/orig
Differential Revision: [D108312283](https://our.internmc.facebook.com/intern/diff/D108312283/)
@diff-train-skip-merge
---------
Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
Stack from ghstack (oldest at bottom):
Adds the `et_vk.linear_q4gsw` operator (4-bit groupwise-symmetric weight-only linear) to the WebGPU backend: dequantize the packed int4 weight in WGSL (`(q-8)*scale`) and accumulate an fp32 matmul, consuming the serialized `[N, K/2]` uint8 weight directly (no prepack), one workgroup per output row. Mirrors the Vulkan reference (`backends/vulkan/.../impl/QuantizedLinear.cpp`). The dispatch carries a `linear_q4gsw` label for GPU-timestamp-query profiling (mirroring the SDPA kernels). The numerical test suite is in the stacked test diff.
Differential Revision: D108312283
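For readers who want to check the arithmetic, the dequantize-and-accumulate step described above can be sketched as a plain-Python reference. This is an illustration under stated assumptions, not the PR's WGSL kernel or the Vulkan implementation: the function names are hypothetical, the nibble order within each byte (low nibble holds the even K index) is an assumption exposed as a flag, and the scale layout `scales[n][k // group_size]` is likewise assumed.

```python
def unpack_q4(byte, low_nibble_first=True):
    """Split one uint8 into two unsigned 4-bit codes in [0, 15].

    Which nibble holds the even K index is an ASSUMPTION here; the
    real serialized layout may differ, hence the flag.
    """
    lo, hi = byte & 0x0F, (byte >> 4) & 0x0F
    return (lo, hi) if low_nibble_first else (hi, lo)


def linear_q4gsw_ref(x, packed_w, scales, group_size, low_nibble_first=True):
    """Reference out[m][n] = sum_k x[m][k] * (q[n][k] - 8) * scale.

    x:        M rows of K floats
    packed_w: N rows of K/2 bytes (two 4-bit codes per byte)
    scales:   N rows of K/group_size floats (ASSUMED layout)
    """
    k = 2 * len(packed_w[0])
    if k % group_size != 0:
        raise ValueError("K must be a multiple of group_size")
    out = []
    for row in x:
        if len(row) != k:
            raise ValueError("input row length must equal K")
        out_row = []
        for n, w_row in enumerate(packed_w):
            acc = 0.0
            for j, byte in enumerate(w_row):
                q0, q1 = unpack_q4(byte, low_nibble_first)
                for kk, q in ((2 * j, q0), (2 * j + 1, q1)):
                    # Symmetric dequantization: subtract the zero point 8,
                    # then apply the scale for this K group.
                    acc += row[kk] * (q - 8) * scales[n][kk // group_size]
            out_row.append(acc)
        out.append(out_row)
    return out


# One input row, one output row, K=4, two groups of size 2.
result = linear_q4gsw_ref(
    x=[[1.0, 2.0, 3.0, 4.0]],
    packed_w=[[0x98, 0x7A]],
    scales=[[0.5, 2.0]],
    group_size=2,
)
```

With the assumed low-nibble-first order, the bytes `0x98, 0x7A` decode to codes 8, 9, 10, 7, that is, signed values 0, 1, 2, -1, so the example evaluates to `1*0*0.5 + 2*1*0.5 + 3*2*2.0 + 4*(-1)*2.0 = 5.0`. A reference like this is mainly useful as an oracle for small odd-shape cases.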